from typing import TypedDict, Dict, List, Annotated, Union
from langgraph.constants import START, END
from langgraph.graph import add_messages, StateGraph
from EnvironmentService import EnvironmentService
import json
import operator
from langgraph.types import Send
from Node import TestClazz
from langchain_core.messages import HumanMessage, BaseMessage
import re
from Config import *
from PromptTemplate import Evaluator_Init_State_Template, Evaluation_Template, Evaluator_Analyzer_Template
from Generator import generator_graph
from Utils import validate_json_structure


class EvaluatorState(TypedDict):
    """Top-level state for the evaluator graph.

    Carries the description of the method under test, the previously
    generated test cases to be evaluated, and the results accumulated as
    the graph runs (feedbacks, regenerated test cases, reports).
    """
    envServer: EnvironmentService  # Environment service
    package_name: str  # Package name
    method_id: int  # ID of the function under test
    method_code: str  # Code of the function under test
    method_signature: str  # Signature of the function under test
    start_line: int  # Start line of the function under test
    end_line: int  # End line of the function under test
    class_name: str  # Class name
    full_method_name: str  # Fully qualified method name
    method_summary: str  # Summary of the function's main functionality/purpose
    old_test_cases: List[Dict]  # List of old test cases
    evaluation_feedbacks: List[Dict]  # List of evaluation feedbacks
    test_cases: Annotated[List, operator.add]  # List of generated test cases (merged across parallel branches)
    final_test_case: str  # Final generated test case
    test_result: str  # Test result
    test_report: str  # Test report
    # Interaction history (for LLM processing)
    messages: Annotated[List[Union[dict, BaseMessage]], add_messages]


class SubState(TypedDict):
    """Per-branch state handed to `testCaseGenerator` via a `Send`.

    A trimmed-down copy of `EvaluatorState` plus the single requirement
    (evaluation feedback) that the generator subgraph should act on.
    """
    envServer: EnvironmentService  # Environment service
    package_name: str  # Package name
    method_code: str  # Code of the function under test
    method_signature: str  # Signature of the function under test
    full_method_name: str  # Fully qualified method name
    start_line: int  # Start line of the function under test
    end_line: int  # End line of the function under test
    class_name: str  # Class name
    method_summary: str  # Summary of the function's main functionality/purpose
    test_case_to_update: str  # Test case to update (empty for "add" requirements)
    requirement: Dict  # The single evaluation feedback this branch works on


def evaluator_init_state(state: EvaluatorState):
    """Seed the conversation history with the evaluator's system prompt."""
    initial_messages = Evaluator_Init_State_Template.invoke({}).to_messages()
    return {'messages': initial_messages}


def test_case_evaluation(state: EvaluatorState):
    """Ask the LLM to evaluate the existing test cases and parse its verdicts.

    Builds one summary section per old test case, sends an evaluation prompt
    to the LLM, and extracts a JSON object with per-test-case decisions
    ("keep" / "update" / "discard") plus suggested additional test cases.
    Retries up to 3 times with corrective feedback when the response has no
    JSON block or does not match the expected schema.

    Returns a partial state update:
        messages: the evaluation prompt messages (plus the accepted LLM
            answer on success).
        evaluation_feedbacks: at most 5 validated feedback dicts; each
            evaluation carries the original test case content under
            "test_case", each addition gets decision "add". Empty list when
            all retries fail (previously this raised NameError).
    """
    def extract_feedback(content: str):
        # Pull the first ```json ... ``` fenced block; None when absent.
        match = re.search(r'```json(.*?)```', content, re.DOTALL)
        return json.loads(match.group(1)) if match else None

    # Build one markdown summary section per prior test case (ids are 1-based
    # because the LLM reports decisions by "Test Case N").
    test_case_total_summary = []
    for index, test_case in enumerate(state['old_test_cases'], start=1):
        assert isinstance(test_case, TestClazz)
        temp_prompt = f"## Test Case {index} Evaluation\n"
        temp_prompt += f"### Test Point\n{test_case.test_points}\n"
        temp_prompt += f"### Content\n{test_case.content}\n"
        temp_prompt += f"### Test Result\n{test_case.test_report}\n"
        temp_prompt += f"### Coverage Rate\n{test_case.coverage_rate}\n"
        temp_prompt += f"### Coverage Lines\n{test_case.coverage_lines}\n"
        temp_prompt += f"### Mutation Score\n{test_case.mutation_score}\n"
        mutants_str = "\n".join([f"Line: {mutant['Line']} - Description: {mutant['Description']} - Status: {mutant['Status']}" for mutant in test_case.mutants])
        temp_prompt += f"### Mutants\n{mutants_str}\n"
        temp_prompt += f"### Found Bugs\n{test_case.find_bugs}\n"
        # NOTE(review): invoke() on a prompt template typically returns a
        # PromptValue, not str — confirm the join below does not TypeError.
        test_summary = Evaluator_Analyzer_Template.invoke({'test_case_info': temp_prompt})
        test_case_total_summary.append(test_summary)

    test_case_total_prompt = "\n\n".join(test_case_total_summary)
    prompt = Evaluation_Template.invoke({'method_code': state['method_code'],
                                         'start_line': state['start_line'],
                                         'end_line': state['end_line'],
                                         'method_summary': state['method_summary'],
                                         'test_cases_summary': test_case_total_prompt})
    valid_prompt = prompt.to_messages()
    # extend()/append() mutate state["messages"] in place so retries see the
    # full conversation so far.
    query_prompt = state["messages"]
    query_prompt.extend(valid_prompt)
    assert query_prompt is not None
    result = llm.invoke(query_prompt)
    state['messages'].append(result)

    # Validation schemas are loop-invariant — define them once.
    expected_schema = {
        "evaluations": list,
        "addition": list
    }
    evaluation_item_schema = {
        "test_case_id": int,
        "decision": str,
        "reason": str,
        "suggestion": str
    }
    addition_schema = {
        "description": str,
        "test_point": str
    }

    # Default to no feedback: previously ret_feedbacks was only bound in the
    # success branch, so exhausting the retries raised NameError on return.
    ret_feedbacks = []
    try_times = 0
    max_try_times = 3
    while try_times < max_try_times:
        try:
            evaluation_feedbacks = extract_feedback(result.content)
        except json.JSONDecodeError:
            evaluation_feedbacks = None

        # Reset per-iteration so stale values from a prior attempt can't leak.
        valid_1 = valid_2 = False
        valid = validate_json_structure(evaluation_feedbacks, expected_schema)
        if valid:
            valid_1 = all(validate_json_structure(item, evaluation_item_schema) for item in evaluation_feedbacks["evaluations"])
            valid_2 = all(validate_json_structure(item, addition_schema) for item in evaluation_feedbacks["addition"])

        if valid and valid_1 and valid_2:
            valid_prompt.append(result)
            for item in evaluation_feedbacks["evaluations"]:
                # Re-attach the original test case source (test_case_id is 1-based).
                item["test_case"] = state['old_test_cases'][item["test_case_id"] - 1].content
                ret_feedbacks.append(item)
            for item in evaluation_feedbacks["addition"]:
                item["test_case"] = ""
                item["decision"] = "add"
                ret_feedbacks.append(item)
            break

        try_times += 1
        if try_times >= max_try_times:
            # Out of retries: skip the corrective LLM call whose result could
            # never be validated (previously a wasted invocation).
            break
        if evaluation_feedbacks is None:
            feedback = HumanMessage(
                "Can not find json data in the response, please provide the test points in the triple backticks."
                "And make sure the format is following the JSON format."
            )
        else:
            feedback = HumanMessage(
                "Your previous response does not match the required JSON format.\n\n"
                "Please revise your response strictly according to this JSON schema:\n\n"
                "```json\n"
                "{\n"
                "  \"evaluations\": [\n"
                "    {\n"
                "      \"test_case_id\": int,\n"
                "      \"decision\": \"keep\" | \"update\" | \"discard\",\n"
                "      \"reason\": \"Clearly explain your decision.\",\n"
                "      \"suggestion\": \"Provide suggestions only if the decision is 'update', otherwise leave empty.\"\n"
                "    }\n"
                "  ],\n"
                "  \"addition\": [\n"
                "    {\n"
                "      \"description\": \"Detailed description of any recommended additional test case.\",\n"
                "      \"test_point\": \"Clearly state what scenario or edge case this test should cover.\"\n"
                "    }\n"
                "  ]\n"
                "}\n"
                "```\n\n"
                "Do not include any additional content outside this structure."
            )
        state['messages'].append(feedback)
        result = llm.invoke(state['messages'])
        state["messages"].append(result)

    # Cap the fan-out at 5 downstream generation branches.
    return {"messages": valid_prompt, "evaluation_feedbacks": ret_feedbacks[:5]}


def testCaseGenerator(state: SubState):
    """Run the generator subgraph on one evaluation feedback and collect its report.

    Invoked once per `Send` fan-out branch; the returned single-element list
    is merged into `EvaluatorState.test_cases` by the `operator.add` reducer.
    """
    generator_input = {
        "envServer": state["envServer"],
        "package_name": state["package_name"],
        "method_signature": state["method_signature"],
        "method_code": state["method_code"],
        "requirement": state["requirement"],
        "start_line": state["start_line"],
        "end_line": state["end_line"],
        "class_name": state["class_name"],
        "full_method_name": state["full_method_name"],
        "method_summary": state["method_summary"],
        "test_case": state["test_case_to_update"],
        "generate_or_evaluation": "evaluation",
    }
    response = generator_graph.invoke(generator_input)
    # Keep only the report fields relevant to the evaluator.
    report_fields = ("test_result", "find_bug", "requirement", "test_case",
                     "test_report", "coverage_report", "mutation_report")
    test_report = {field: response[field] for field in report_fields}
    return {"test_cases": [test_report]}


def sendTestPoints(state: EvaluatorState):
    """Fan out each actionable feedback ('update'/'add') to testCaseGenerator.

    Returns one `Send` per actionable feedback; feedbacks with decision
    'keep'/'discard' are dropped. When no feedbacks exist at all, routes
    back to re-run the evaluation.
    """
    feedbacks = state["evaluation_feedbacks"]
    if feedbacks is None:
        # NOTE(review): "test_case_evaluation" is not listed in the
        # conditional-edge path list below — confirm langgraph accepts this route.
        return "test_case_evaluation"
    dispatches = []
    for fb in feedbacks:
        if fb.get("decision") not in ['update', 'add']:
            continue
        branch_state = {"envServer": state["envServer"],
                        "requirement": fb,
                        "package_name": state["package_name"],
                        "method_signature": state["method_signature"],
                        "method_code": state["method_code"],
                        "class_name": state["class_name"],
                        "start_line": state["start_line"],
                        "end_line": state["end_line"],
                        "full_method_name": state["full_method_name"],
                        "method_summary": state["method_summary"],
                        "test_case_to_update": fb['test_case']}
        dispatches.append(Send("testCaseGenerator", branch_state))
    return dispatches


def testCaseAcceptor(state: EvaluatorState):
    test_cases = state["test_cases"]
    return {"test_cases": test_cases}

# Assemble the evaluator graph:
#   init prompt -> evaluate old tests -> fan out to generator branches -> accept results.
evaluator_graph = StateGraph(EvaluatorState)

evaluator_graph.add_node("evaluator_init_state", evaluator_init_state)
evaluator_graph.add_node("test_case_evaluation", test_case_evaluation)
evaluator_graph.add_node("testCaseGenerator", testCaseGenerator)
evaluator_graph.add_node("testCaseAcceptor", testCaseAcceptor)

evaluator_graph.add_edge(START, "evaluator_init_state")
evaluator_graph.add_edge("evaluator_init_state", "test_case_evaluation")
# sendTestPoints returns Send objects, one generator branch per actionable feedback.
evaluator_graph.add_conditional_edges("test_case_evaluation", sendTestPoints, ["testCaseGenerator"])
evaluator_graph.add_edge("testCaseGenerator", "testCaseAcceptor")
evaluator_graph.add_edge("testCaseAcceptor", END)

# Rebind the name to the compiled, runnable graph.
evaluator_graph = evaluator_graph.compile()

if __name__ == "__main__":
    # When run as a script, render the compiled graph to a PNG for inspection.
    try:
        png_bytes = evaluator_graph.get_graph(xray=1).draw_mermaid_png()
        with open("Evaluator.png", "wb") as out_file:
            out_file.write(png_bytes)
    except Exception as e:
        # Rendering needs optional tooling/network; failure is non-fatal.
        print(f"Failed to save image: {e}")
    print("Graph compiled successfully.")
